import os
import sys
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import T5TokenizerFast, T5ForConditionalGeneration, T5Tokenizer
import numpy as np
from tqdm import tqdm


from my_model import DualObjectFMRItoT5
from creat_data import load_all_data


def _snapshot_state_dict(model):
    """Return an independent CPU copy of ``model.state_dict()``.

    ``state_dict()`` hands back references to the live parameter tensors, so
    a plain ``dict.copy()`` keeps changing as training continues. Copying
    each tensor to the CPU freezes the values without using extra GPU memory.
    """
    return {
        name: value.detach().to("cpu", copy=True) if torch.is_tensor(value) else value
        for name, value in model.state_dict().items()
    }


def _train_one_epoch(model, dataloader, optimizer, device, desc):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    model.train()
    running_loss = 0
    progress = tqdm(dataloader, desc=desc)

    # Count steps ourselves: ``tqdm.n`` is only refreshed when the bar is
    # redrawn, so it cannot be used as a reliable divisor for a running mean.
    for step, batch in enumerate(progress, start=1):
        fmri_vectors = batch["fmri"].to(device)
        labels = [label.to(device) for label in batch["labels"]]

        # The model returns one output per label tensor; average their losses.
        outputs = model(fmri_vectors, labels=labels)
        loss = sum(output.loss for output in outputs) / len(outputs)

        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix(loss=running_loss / step)

    return running_loss / len(dataloader)


def _evaluate(model, dataloader, device, tokenizer, desc):
    """Compute the mean validation loss and collect a few decoded samples.

    Returns a ``(mean_loss, generated, references)`` tuple. ``generated`` and
    ``references`` are parallel lists with one entry per sampled example, each
    entry being a list of strings. Samples are taken from the first two
    examples of the first two batches, and only when ``tokenizer`` is given.
    """
    model.eval()
    total_loss = 0
    generated_sentences = []
    reference_sentences = []
    progress = tqdm(dataloader, desc=desc)

    with torch.no_grad():
        for batch_idx, batch in enumerate(progress):
            fmri_vectors = batch["fmri"].to(device)
            labels = [label.to(device) for label in batch["labels"]]

            outputs = model(fmri_vectors, labels=labels)
            total_loss += sum(output.loss.item() for output in outputs) / len(outputs)

            # Decoding needs a tokenizer; without one only the loss is reported.
            if tokenizer is not None and batch_idx < 2:
                for i in range(min(2, fmri_vectors.size(0))):
                    # Slice (rather than index) to keep the batch dimension.
                    single_fmri = fmri_vectors[i:i + 1]
                    generated_ids = model.generate_sentences(single_fmri, max_length=50)
                    generated_sentences.append([
                        tokenizer.decode(ids[0], skip_special_tokens=True)
                        for ids in generated_ids
                    ])
                    reference_sentences.append([text[i] for text in batch["text"]])

            progress.set_postfix(loss=total_loss / (batch_idx + 1))

    return total_loss / len(dataloader), generated_sentences, reference_sentences


def train_fmri_to_text_model(model, train_dataloader, val_dataloader, num_epochs=40, stage_1_epochs=30,
                             learning_rate=1e-5, device="cuda", tokenizer=None, restore_best=False):
    """Train ``model`` in two stages and return it.

    Stage 1 trains whatever parameters have ``requires_grad=True`` when the
    model is passed in, for ``stage_1_epochs`` epochs at ``learning_rate``.
    Stage 2 calls ``model.unfreeze_all()`` and trains for the remaining
    ``num_epochs - stage_1_epochs`` epochs at ``learning_rate / 10``. Each
    stage gets a fresh AdamW optimizer and ReduceLROnPlateau scheduler.

    Args:
        model: Module that is callable as ``model(fmri, labels=labels)`` and
            returns a sized iterable of objects with a ``.loss`` attribute.
            It must also provide ``unfreeze_all()`` and
            ``generate_sentences(fmri, max_length=...)``.
        train_dataloader: Sized iterable of batches. Each batch is a mapping
            with a ``"fmri"`` tensor and a ``"labels"`` iterable of tensors.
        val_dataloader: Same as ``train_dataloader``, plus a ``"text"`` entry
            holding per-segment sequences indexed by position in the batch.
        num_epochs: Total number of epochs across both stages.
        stage_1_epochs: Number of epochs spent in stage 1.
        learning_rate: Stage 1 learning rate.
        device: Device the model and batches are moved to.
        tokenizer: Object with ``decode(ids, skip_special_tokens=True)``.
            When ``None``, sample generation during validation is skipped.
        restore_best: When true, keep a CPU snapshot of the weights with the
            lowest validation loss seen in any stage and load it back before
            returning. When false (the default) the last-epoch weights are
            returned and no snapshot is taken.

    Returns:
        The trained model, on ``device``.
    """
    # Training stages
    stages = [
        {"name": "Stage 1: Adapters only", "epochs": stage_1_epochs, "lr": learning_rate},
        {"name": "Stage 2: All layers", "epochs": num_epochs - stage_1_epochs, "lr": learning_rate / 10},
    ]

    model = model.to(device)

    # Tracked across both stages so "best" means best of the whole run.
    best_val_loss = float("inf")
    best_model_state = None

    global_epoch = 0
    for stage_idx, stage in enumerate(stages):
        print(f"\n===== Starting {stage['name']} =====")
        # Stage 1 trains the model in the state it was passed in; only
        # stage 2 changes which parameters are trainable.
        if stage_idx == 1:
            model.unfreeze_all()

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable_params, lr=stage["lr"])

        # ``verbose=True`` is deprecated in recent PyTorch releases, so
        # learning-rate reductions are reported explicitly below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=2
        )

        for _ in range(stage["epochs"]):
            global_epoch += 1

            avg_train_loss = _train_one_epoch(
                model, train_dataloader, optimizer, device,
                desc=f"Epoch {global_epoch}/{num_epochs} [Train]",
            )
            avg_val_loss, generated_sentences, reference_sentences = _evaluate(
                model, val_dataloader, device, tokenizer,
                desc=f"Epoch {global_epoch}/{num_epochs} [Val]",
            )
            print(f"Epoch {global_epoch}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

            # Show at most three of the collected samples.
            samples = zip(generated_sentences[:3], reference_sentences[:3])
            for example_no, (generated, reference) in enumerate(samples, start=1):
                print(f"\nExample {example_no}:")
                for segment_no, (gen_text, ref_text) in enumerate(zip(generated, reference), start=1):
                    print(f"Segment {segment_no}:")
                    print(f"  Generated: {gen_text}")
                    print(f"  Reference: {ref_text}")
                print("-" * 80)

            previous_lrs = [group["lr"] for group in optimizer.param_groups]
            scheduler.step(avg_val_loss)
            for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
                if group["lr"] != old_lr:
                    print(f"Epoch {global_epoch}: reducing learning rate of group {group_idx} "
                          f"to {group['lr']:.4e}.")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                if restore_best:
                    best_model_state = _snapshot_state_dict(model)
                print(f"New best model with validation loss: {best_val_loss:.4f}")

    if restore_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model


def main():
    """Train the fMRI-to-text model end to end and save its weights.

    Loads the datasets via ``load_all_data()``, builds a
    ``DualObjectFMRItoT5`` model on top of ``t5-base``, runs the two-stage
    training loop and writes the resulting state dict to
    ``model_save_dir/fmri_to_text_model_abl.pt``.
    """
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")

    train_dataset, test_dataset, _ = load_all_data()
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

    model = DualObjectFMRItoT5(fmri_dim=15724, t5_model_name="t5-base", num_tokens=5)

    # Prefer the second GPU, but fall back to the first one on single-GPU
    # hosts where "cuda:1" is not a valid device, and to the CPU otherwise.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")

    # Create the output directory before training so that a missing or
    # unwritable path is reported immediately, not after the full run.
    model_save_dir = "model_save_dir"
    os.makedirs(model_save_dir, exist_ok=True)

    trained_model = train_fmri_to_text_model(
        model, train_dataloader, val_dataloader,
        num_epochs=60, learning_rate=1e-5, device=device, stage_1_epochs=45, tokenizer=tokenizer
    )

    # Save
    torch.save(trained_model.state_dict(), f"{model_save_dir}/fmri_to_text_model_abl.pt")
    print("Training complete and model saved!")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
